import torch
import torch.nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_add_pool

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from sklearn.model_selection import train_test_split

class GIN(torch.nn.Module):
    """Graph Isomorphism Network for graph-level classification.

    Stacks ``num_layers`` GINConv layers, each wrapping a two-layer MLP
    (Linear -> ReLU -> Linear) with a trainable epsilon. After every
    convolution the node embeddings go through ReLU and then BatchNorm1d.
    The final embeddings are sum-pooled per graph and projected to class
    log-probabilities.

    Args:
        num_features: Width of the input node features.
        num_classes: Number of output classes.
        hidden_dim: Width of every hidden MLP layer / node embedding.
        num_layers: Number of GINConv layers to stack.
    """

    def __init__(self, num_features, num_classes, hidden_dim, num_layers):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        # Only the first layer sees raw node features; every later layer
        # consumes hidden_dim-wide embeddings from the layer before it.
        in_dim = num_features
        for _ in range(num_layers):
            mlp = torch.nn.Sequential(
                torch.nn.Linear(in_dim, hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim, hidden_dim),
            )
            self.convs.append(GINConv(mlp, train_eps=True))
            self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))
            in_dim = hidden_dim

        # Classifier head applied to the pooled per-graph embedding.
        self.linear = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        """Compute class log-probabilities for every graph in a mini-batch.

        Args:
            x: Node feature matrix (assumed ``[num_nodes, num_features]`` --
                confirm against the data loader).
            edge_index: Graph connectivity, passed unchanged to each GINConv.
            batch: Node-to-graph assignment vector consumed by the pooling.

        Returns:
            ``log_softmax`` over the last dimension of the classifier output.
        """
        h = x
        for conv, norm in zip(self.convs, self.batch_norms):
            # Note the order used here: activation first, then batch norm.
            h = norm(F.relu(conv(h, edge_index)))

        graph_embedding = global_add_pool(h, batch)
        logits = self.linear(graph_embedding)
        return F.log_softmax(logits, dim=-1)
    

# Load the dataset (MUTAG as the example).
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')

# Hold out 20% of the graphs for evaluation; the fixed random_state makes the
# split identical across runs.
train_dataset, test_dataset = train_test_split(
    dataset,
    test_size=0.2,
    random_state=42,
)
# Only the training loader shuffles; accuracy is a plain count, so the
# evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Build the model and its optimizer.
model = GIN(
    num_features=dataset.num_features,
    num_classes=dataset.num_classes,
    hidden_dim=128,
    num_layers=3,
)
optimizer = Adam(model.parameters(), lr=0.01)

def train():
    """Run one optimisation epoch of the global ``model`` over ``train_loader``.

    Returns:
        The epoch's per-graph mean NLL loss over ``train_dataset``.
    """
    model.train()
    running_loss = 0
    for mini_batch in train_loader:
        optimizer.zero_grad()
        log_probs = model(mini_batch.x, mini_batch.edge_index, mini_batch.batch)
        batch_loss = F.nll_loss(log_probs, mini_batch.y)
        batch_loss.backward()
        optimizer.step()
        # nll_loss is a batch mean; re-weight by the batch's graph count so a
        # smaller final batch does not skew the epoch average.
        running_loss += batch_loss.item() * mini_batch.num_graphs
    return running_loss / len(train_dataset)

def test(loader):
    """Return the classification accuracy of the global ``model`` on ``loader``.

    Args:
        loader: Iterable of mini-batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y``, with a ``dataset`` attribute whose length is
            used as the denominator.

    Returns:
        Number of correct arg-max predictions divided by ``len(loader.dataset)``.
    """
    model.eval()
    correct = 0
    # Pure inference: disable autograd so no backward graph (and its saved
    # activations) is recorded for every evaluated batch.
    with torch.no_grad():
        for data in loader:
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)
            correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

# Train the model for 100 epochs, reporting the epoch loss and the accuracy on
# both splits after each one.
for epoch in range(100):
    loss = train()
    train_acc, test_acc = test(train_loader), test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

# torch.save(model.state_dict(), './gin_model.pth')